{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# KMeans\n",
    "## 1 聚类算法\n",
    "对于\"监督学习\"(supervised learning)，其训练样本是带有标记信息的，并且监督学习的目的是：对带有标记的数据集进行模型学习，从而便于对新的样本进行分类。而在“无监督学习”(unsupervised learning)中，训练样本的标记信息是未知的，目标是通过对无标记训练样本的学习来揭示数据的内在性质及规律，为进一步的数据分析提供基础。对于无监督学习，应用最广的便是\"聚类\"(clustering)。\n",
    "  “聚类算法”试图将数据集中的样本划分为若干个通常是不相交的子集，每个子集称为一个“簇”(cluster)，通过这样的划分，每个簇可能对应于一些潜在的概念或类别。\n",
    "  我们可以通过下面这个图来理解：\n",
    "![title](pic/22.jpg)\n",
    "上图是未做标记的样本集，通过他们的分布，我们很容易对上图中的样本做出以下几种划分。\n",
    "  当需要将其划分为两个簇时，即 k=2 时：\n",
    "![title](pic/23.jpg) ![title](pic/24.jpg)\n",
    "当需要将其划分为四个簇时，即 k=4 时：\n",
    "![title](pic/25.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2 kmeans算法\n",
    "kmeans算法又名k均值算法。其算法思想大致为：先从样本集中随机选取 k 个样本作为簇中心，并计算所有样本与这 k 个“簇中心”的距离，对于每一个样本，将其划分到与其距离最近的“簇中心”所在的簇中，对于新的簇计算各个簇的新的“簇中心”。\n",
    "  根据以上描述，我们大致可以猜测到实现kmeans算法的主要三点：\n",
    "  \n",
    "  （1）簇个数 k 的选择\n",
    "  \n",
    "  （2）各个样本点到“簇中心”的距离\n",
    "  \n",
    "  （3）根据新划分的簇，更新“簇中心”\n",
    "### 2.1 kmeans算法要点\n",
    "（1） k 值的选择\n",
    "\n",
    "     k 的选择一般是按照实际需求进行决定，或在实现算法时直接给定 k 值。\n",
    "     \n",
    "（2） 距离的度量\n",
    "\n",
    "     给定样本\n",
    "     $x^{(i)}=\\left\\{x_{1}^{(i)}, x_{2}^{(i)}, \\ldots, x_{n}^{(i)},\\right\\} \\stackrel{(i)}{=}\\left\\{x_{1}^{(j)}, x_{2}^{(j)}, \\ldots, x_{n}^{(j)},\\right\\}$,其中i,j=1,2,...,m，表示样本数，n表示特征数  。距离的度量方法主要分为以下几种：\n",
    "     ![title](pic/26.png)\n",
    "\n",
    "(3） 更新“簇中心”\n",
    "\n",
    "     对于划分好的各个簇，计算各个簇中的样本点均值，将其均值作为新的簇中心。\n",
    "### 2.2 kmeans算法过程\n",
    "\n",
    "![title](pic/27.png)\n",
    "![title](pic/26.jpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as  np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loadDataSet(filename):\n",
    "    dataMat = []\n",
    "    with open(filename) as fr:\n",
    "        for line in fr.readlines():\n",
    "            curline = line.strip().split()\n",
    "            fltline = list(map(float, curline))\n",
    "            dataMat.append(fltline)\n",
    "    return np.array(dataMat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataMat=loadDataSet('data/Kmeans/testSet.txt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def distEclud(vecA, vecB):\n",
    "    return np.sqrt(np.power(vecA - vecB, 2).sum())  # not sum(mat) but mat.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def randCent(dataSet, k):\n",
    "    n = np.shape(dataSet)[1]  \n",
    "    centroids = np.mat(np.zeros((k, n)))\n",
    "    for j in range(n):\n",
    "        minj = min(dataSet[:, j])\n",
    "        rangej = float(max(dataSet[:, j]) - minj)\n",
    "        centroids[:, j] = minj + rangej * np.random.rand(k, 1)\n",
    "    return centroids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "matrix([[ 2.0087042 , -0.75674176],\n",
       "        [-2.80963583,  4.9478131 ]])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "randCent(dataMat,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def KMeans(dataSet, k, distMeas=distEclud, createCent=randCent):\n",
    "    m = np.shape(dataSet)[0]\n",
    "    clusterAssment = np.mat(np.zeros((m, 2)))\n",
    "    centroids = randCent(dataSet, k)\n",
    "    clusterchanged = True\n",
    "    while clusterchanged:\n",
    "        clusterchanged = False\n",
    "        for i in range(m):\n",
    "            cluster_i = clusterAssment[i, 0];\n",
    "            dismax = np.inf\n",
    "            for j in range(k):\n",
    "                curdis = distEclud(centroids[j, :], dataSet[i, :])\n",
    "                if curdis < dismax:\n",
    "                    dismax = curdis\n",
    "                    clusterAssment[i, :] = j, dismax\n",
    "            if cluster_i != clusterAssment[i, 0]: clusterchanged = True\n",
    "        print(centroids)\n",
    "        for cent in range(k):\n",
    "            ptsInClust = dataSet[np.nonzero(clusterAssment[:, 0].A == cent)[0]]\n",
    "            centroids[cent, :] = np.mean(ptsInClust, axis=0)\n",
    "    return centroids, clusterAssment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 1.64738477 -3.89985624]\n",
      " [ 2.14955832  2.37770554]]\n",
      "[[-0.43856939 -2.95461287]\n",
      " [ 0.19944238  2.77665202]]\n",
      "[[-0.2897198  -2.83942545]\n",
      " [ 0.08249337  2.94802785]]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(matrix([[-0.2897198 , -2.83942545],\n",
       "         [ 0.08249337,  2.94802785]]), matrix([[1.        , 2.06716812],\n",
       "         [1.        , 3.5681125 ],\n",
       "         [0.        , 5.39850778],\n",
       "         [0.        , 5.1167591 ],\n",
       "         [1.        , 0.89039257],\n",
       "         [1.        , 3.91557751],\n",
       "         [0.        , 0.8730819 ],\n",
       "         [0.        , 3.3862195 ],\n",
       "         [1.        , 2.91888366],\n",
       "         [1.        , 3.24808913],\n",
       "         [0.        , 3.64487896],\n",
       "         [0.        , 2.51060892],\n",
       "         [1.        , 4.12585863],\n",
       "         [1.        , 2.2058353 ],\n",
       "         [0.        , 2.56070545],\n",
       "         [0.        , 1.12895497],\n",
       "         [1.        , 3.07341157],\n",
       "         [1.        , 0.95792926],\n",
       "         [0.        , 3.27439238],\n",
       "         [0.        , 2.95877748],\n",
       "         [1.        , 2.25513093],\n",
       "         [1.        , 1.9098742 ],\n",
       "         [0.        , 2.6496711 ],\n",
       "         [0.        , 3.11424737],\n",
       "         [1.        , 1.93527435],\n",
       "         [1.        , 1.91077281],\n",
       "         [0.        , 2.98768309],\n",
       "         [0.        , 3.73567161],\n",
       "         [1.        , 2.21920342],\n",
       "         [1.        , 3.50771736],\n",
       "         [0.        , 1.70680771],\n",
       "         [0.        , 3.43375907],\n",
       "         [1.        , 1.97525631],\n",
       "         [1.        , 3.47410493],\n",
       "         [0.        , 5.08002556],\n",
       "         [0.        , 2.38231251],\n",
       "         [1.        , 2.87955238],\n",
       "         [1.        , 1.16528929],\n",
       "         [0.        , 3.13146755],\n",
       "         [0.        , 3.57227354],\n",
       "         [1.        , 2.22189947],\n",
       "         [1.        , 2.79163176],\n",
       "         [0.        , 3.67302958],\n",
       "         [0.        , 2.30135571],\n",
       "         [1.        , 2.25897037],\n",
       "         [1.        , 3.38296374],\n",
       "         [0.        , 3.51912833],\n",
       "         [0.        , 3.49617155],\n",
       "         [1.        , 3.46369756],\n",
       "         [1.        , 2.23199271],\n",
       "         [0.        , 2.42109414],\n",
       "         [0.        , 4.18103686],\n",
       "         [1.        , 4.30668597],\n",
       "         [1.        , 4.88624186],\n",
       "         [0.        , 2.94409918],\n",
       "         [0.        , 3.72216392],\n",
       "         [1.        , 2.60421872],\n",
       "         [1.        , 2.61411007],\n",
       "         [0.        , 2.18027864],\n",
       "         [0.        , 2.97324131],\n",
       "         [1.        , 2.84892887],\n",
       "         [1.        , 2.64579665],\n",
       "         [0.        , 3.93982644],\n",
       "         [0.        , 2.42170645],\n",
       "         [1.        , 3.41128037],\n",
       "         [1.        , 1.81064634],\n",
       "         [0.        , 2.58553874],\n",
       "         [0.        , 2.37619479],\n",
       "         [1.        , 2.2484467 ],\n",
       "         [1.        , 1.768216  ],\n",
       "         [0.        , 3.41129738],\n",
       "         [0.        , 4.35917943],\n",
       "         [1.        , 4.49889879],\n",
       "         [1.        , 2.22113752],\n",
       "         [0.        , 4.96632713],\n",
       "         [0.        , 2.59679251],\n",
       "         [1.        , 2.80323577],\n",
       "         [1.        , 3.05175738],\n",
       "         [0.        , 4.88863326],\n",
       "         [0.        , 4.61640218]]))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "KMeans(np.array(dataMat),2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def biKmeans(dataSet, k, distMeas=distEclud):\n",
    "    m = np.shape(dataSet)[0]\n",
    "    clusterAssment = np.mat(np.zeros((m, 2)))\n",
    "    centroid0 = np.mean(dataSet, axis=0).tolist()[0]\n",
    "    cenList = [centroid0]\n",
    "    for j in range(m):\n",
    "        clusterAssment[j, 1] = distMeas(np.mat(centroid0), dataSet[j, :]) ** 2\n",
    "    while len(cenList) < k:\n",
    "        lowestSSE = np.inf\n",
    "        for i in range(len(cenList)):\n",
    "            ptscurrCluster = dataSet[np.nonzero(clusterAssment[:, 0].A == i)[0], :]\n",
    "            centroidMat, splitClustAss = KMeans(ptscurrCluster, 2)\n",
    "            ssesplit = splitClustAss[:, 1].sum()\n",
    "            ssenotsplit = clusterAssment[np.nonzero(clusterAssment[:, 0].A != i), 1].sum()\n",
    "            print(ssesplit, ssenotsplit)\n",
    "            if ssesplit + ssenotsplit < lowestSSE:\n",
    "                lowestSSE = ssenotsplit + ssesplit\n",
    "                bestnewCent = centroidMat\n",
    "                bestClustAss = splitClustAss.copy()\n",
    "                bestCentToSplit = i\n",
    "        bestClustAss[np.nonzero(bestClustAss[:, 0].A == 1)[0], 0] = len(cenList)\n",
    "        bestClustAss[np.nonzero(bestClustAss[:, 0].A == 0)[0], 0] = bestCentToSplit\n",
    "        print('the bestcenttosplit: ', bestCentToSplit)\n",
    "        print('len bestclustass: ', len(bestClustAss))\n",
    "        cenList[bestCentToSplit] = bestnewCent[0, :]\n",
    "        cenList.append(bestnewCent[1, :].tolist()[0])\n",
    "        clusterAssment[np.nonzero(clusterAssment[:, 0].A == bestCentToSplit)[0],\n",
    "        :] = bestClustAss  # reassign new clusters, and SSE\n",
    "    return np.mat(cenList), clusterAssment\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.32909744 -2.2394072 ]\n",
      " [-4.64394088 -3.03177339]]\n",
      "[[ 1.46190604  0.77760244]\n",
      " [-3.54775556 -1.53696152]]\n",
      "[[ 1.92304273  0.70709257]\n",
      " [-3.30703713 -0.97753032]]\n",
      "[[ 2.35560655  0.474688  ]\n",
      " [-3.10932625 -0.45950489]]\n",
      "[[ 2.62924541  0.26552329]\n",
      " [-2.97661844 -0.16775279]]\n",
      "[[ 2.71473038  0.18858278]\n",
      " [-2.9219568  -0.07998038]]\n",
      "248.1507528392998 0.0\n",
      "the bestcenttosplit:  0\n",
      "len bestclustass:  80\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(matrix([[matrix([[2.71473038, 0.18858278]]),\n",
       "          list([-2.9219568000000002, -0.07998037500000002])]], dtype=object),\n",
       " matrix([[0.        , 4.23040738],\n",
       "         [1.        , 3.54441323],\n",
       "         [0.        , 2.51093336],\n",
       "         [1.        , 4.10035377],\n",
       "         [0.        , 3.24316536],\n",
       "         [1.        , 1.7362298 ],\n",
       "         [0.        , 4.16075955],\n",
       "         [1.        , 1.73885412],\n",
       "         [0.        , 1.40701044],\n",
       "         [1.        , 3.27951404],\n",
       "         [0.        , 4.21260813],\n",
       "         [1.        , 3.02239548],\n",
       "         [0.        , 3.1701662 ],\n",
       "         [1.        , 3.12704603],\n",
       "         [0.        , 2.1186188 ],\n",
       "         [1.        , 4.63488064],\n",
       "         [0.        , 1.39036144],\n",
       "         [1.        , 4.03266264],\n",
       "         [0.        , 3.55303008],\n",
       "         [1.        , 2.22090339],\n",
       "         [0.        , 2.71302541],\n",
       "         [1.        , 2.86858385],\n",
       "         [0.        , 2.1593047 ],\n",
       "         [1.        , 2.74091587],\n",
       "         [0.        , 3.80919097],\n",
       "         [1.        , 2.61897662],\n",
       "         [0.        , 2.24625295],\n",
       "         [1.        , 3.31006662],\n",
       "         [0.        , 3.23145889],\n",
       "         [1.        , 0.8985653 ],\n",
       "         [0.        , 4.57094873],\n",
       "         [1.        , 2.30725376],\n",
       "         [0.        , 2.63296708],\n",
       "         [1.        , 1.87938582],\n",
       "         [0.        , 1.94176803],\n",
       "         [1.        , 3.43472889],\n",
       "         [0.        , 5.06951182],\n",
       "         [1.        , 3.10496698],\n",
       "         [0.        , 2.84974243],\n",
       "         [1.        , 3.30335224],\n",
       "         [0.        , 3.74872061],\n",
       "         [1.        , 3.01140421],\n",
       "         [0.        , 3.43586104],\n",
       "         [1.        , 4.22905405],\n",
       "         [0.        , 3.01587999],\n",
       "         [1.        , 3.37265594],\n",
       "         [0.        , 4.02149516],\n",
       "         [1.        , 2.49774371],\n",
       "         [0.        , 2.71916508],\n",
       "         [1.        , 4.48059931],\n",
       "         [0.        , 3.226235  ],\n",
       "         [1.        , 4.10492696],\n",
       "         [0.        , 5.06314252],\n",
       "         [1.        , 3.91150753],\n",
       "         [0.        , 3.45048873],\n",
       "         [1.        , 3.09539939],\n",
       "         [0.        , 1.78885675],\n",
       "         [1.        , 2.75259234],\n",
       "         [0.        , 3.4706773 ],\n",
       "         [1.        , 3.50135292],\n",
       "         [0.        , 2.30753653],\n",
       "         [1.        , 2.9861275 ],\n",
       "         [0.        , 4.20829999],\n",
       "         [1.        , 1.96469809],\n",
       "         [0.        , 3.84455058],\n",
       "         [1.        , 3.87931185],\n",
       "         [0.        , 2.78210484],\n",
       "         [1.        , 3.0353152 ],\n",
       "         [0.        , 2.99236808],\n",
       "         [1.        , 3.26293961],\n",
       "         [0.        , 3.82194406],\n",
       "         [1.        , 2.69242622],\n",
       "         [0.        , 1.7535225 ],\n",
       "         [1.        , 3.27005353],\n",
       "         [0.        , 3.37466029],\n",
       "         [1.        , 2.07372417],\n",
       "         [0.        , 2.85987519],\n",
       "         [1.        , 2.92903276],\n",
       "         [0.        , 2.63237797],\n",
       "         [1.        , 3.4568445 ]]))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "biKmeans(np.array(dataMat),2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
